"""
OpenGauss 数据库后端 - 基于 PostgreSQL，但针对 OpenGauss 进行了优化
"""
from django.db.backends.postgresql import base
from django.core.exceptions import ImproperlyConfigured
from django.utils.asyncio import async_unsafe

# libpq connection keywords that may be given in DATABASES[...]["OPTIONS"];
# get_connection_params() copies each one present into the driver kwargs.
_EXTRA_CONN_OPTION_KEYS = set(
    (
        # TLS / certificate handling
        "sslrootcert",
        "sslcert",
        "sslkey",
        "sslcrl",
        "sslcompression",
        "sslpassword",
        # session / service selection
        "target_session_attrs",
        "service",
        "options",
        # TCP keepalive tuning
        "keepalives",
        "keepalives_idle",
        "keepalives_interval",
        "keepalives_count",
        # authentication transport
        "gssencmode",
        "channel_binding",
    )
)

# 导入自定义的模块
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .features import DatabaseFeatures
from .schema import DatabaseSchemaEditor
from .client import DatabaseClient
from .creation import DatabaseCreation


class DatabaseWrapper(base.DatabaseWrapper):
    """
    Database connection wrapper for OpenGauss.

    Reuses Django's PostgreSQL backend and overrides the parts where
    OpenGauss differs:

    * column types, with serial-based auto fields instead of identity columns;
    * the database version check;
    * the backend helper classes (introspection, operations, features,
      schema editor, client, creation).
    """
    vendor = 'opengauss'
    display_name = 'OpenGauss'

    # Mapping from Django model field type to OpenGauss column type.
    # Auto fields use the serial family.
    data_types = {
        'AutoField': 'serial',
        'BigAutoField': 'bigserial',
        'BinaryField': 'bytea',
        'BooleanField': 'boolean',
        'CharField': 'varchar(%(max_length)s)',
        'DateField': 'date',
        'DateTimeField': 'timestamp with time zone',
        'DecimalField': 'numeric(%(max_digits)s, %(decimal_places)s)',
        'DurationField': 'interval',
        'FileField': 'varchar(%(max_length)s)',
        'FilePathField': 'varchar(%(max_length)s)',
        'FloatField': 'double precision',
        'IntegerField': 'integer',
        'BigIntegerField': 'bigint',
        'IPAddressField': 'inet',
        'GenericIPAddressField': 'inet',
        'JSONField': 'jsonb',
        'NullBooleanField': 'boolean',
        'OneToOneField': 'integer',
        # Without an entry here the field has no db type and its column
        # would be silently omitted from CREATE TABLE.
        'PositiveBigIntegerField': 'bigint',
        'PositiveIntegerField': 'integer',
        'PositiveSmallIntegerField': 'smallint',
        'SlugField': 'varchar(%(max_length)s)',
        'SmallAutoField': 'smallserial',
        'SmallIntegerField': 'smallint',
        'TextField': 'text',
        'TimeField': 'time',
        'UUIDField': 'uuid',
    }

    # Clear the inherited data type suffixes. In Django 4.1+ the PostgreSQL
    # backend appends "GENERATED BY DEFAULT AS IDENTITY" to auto fields, and
    # that must not follow the serial types above.
    data_types_suffix = {}

    # CHECK constraints for the non-negative integer fields. The column name
    # is double-quoted so reserved words or mixed-case names remain valid SQL.
    data_type_check_constraints = {
        'PositiveBigIntegerField': '"%(column)s" >= 0',
        'PositiveIntegerField': '"%(column)s" >= 0',
        'PositiveSmallIntegerField': '"%(column)s" >= 0',
    }

    # Backend helper classes provided by this package.
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    features_class = DatabaseFeatures
    # Django's base wrapper looks up ``SchemaEditorClass``. The lowercase
    # alias is kept for code that referenced it directly.
    SchemaEditorClass = DatabaseSchemaEditor
    schema_editor_class = DatabaseSchemaEditor
    client_class = DatabaseClient
    creation_class = DatabaseCreation

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Record which driver library is in use when this package is the
        # configured ENGINE.
        if self.settings_dict['ENGINE'] == 'django_opengauss':
            self._ensure_connection_library()

    def _ensure_connection_library(self):
        """
        Record the DB-API driver in use in ``self._driver_tag``.

        Django's PostgreSQL backend (4.2+) prefers psycopg 3 when it is
        installed and exposes that choice as ``base.is_psycopg3``. When that
        flag is true, the tag follows it, so it names the driver actually used
        for connections. Otherwise psycopg2 is tried first, then psycopg 3.

        Raises:
            ImproperlyConfigured: if neither psycopg2 nor psycopg can be imported.
        """
        if getattr(base, 'is_psycopg3', False):
            self._driver_tag = "psycopg3"
            return
        try:
            import psycopg2  # noqa: F401
        except ImportError:
            try:
                import psycopg  # noqa: F401
            except ImportError as exc:
                raise ImproperlyConfigured(
                    "Error loading psycopg module: install psycopg2-binary or psycopg"
                ) from exc
            self._driver_tag = "psycopg3"
        else:
            self._driver_tag = "psycopg2"

    def get_connection_params(self):
        """
        Build the keyword arguments passed to the driver's connect call.

        Starts from the PostgreSQL backend's parameters, then applies these
        entries from ``OPTIONS``:

        * ``application_name`` (defaults to ``'Django-OpenGauss'`` if unset);
        * ``connect_timeout``, ``sslmode`` and every key in
          ``_EXTRA_CONN_OPTION_KEYS``;
        * ``extra_params``: a dict of further connection keywords. Each one is
          applied only if that key is not already set.

        The backend-only options ``cursor_itersize`` and ``extra_params`` are
        removed from the result, so the driver never sees them as connection
        keywords.

        Returns:
            dict: connection keyword arguments.
        """
        options = self.settings_dict.get('OPTIONS', {})
        conn_params = super().get_connection_params()

        # These options are consumed by this backend, not by the driver. The
        # parent may copy OPTIONS wholesale into conn_params, so strip them
        # here before they reach connect().
        cursor_itersize = conn_params.pop('cursor_itersize', None)
        if cursor_itersize is not None:
            options.setdefault('cursor_itersize', cursor_itersize)
        conn_params.pop('extra_params', None)

        # Application name reported to the server.
        if 'application_name' in options:
            conn_params['application_name'] = options['application_name']
        elif 'application_name' not in conn_params:
            conn_params['application_name'] = 'Django-OpenGauss'

        # Connection timeout.
        if 'connect_timeout' in options:
            conn_params['connect_timeout'] = options['connect_timeout']

        # SSL and network settings.
        if 'sslmode' in options:
            conn_params['sslmode'] = options['sslmode']

        for key in _EXTRA_CONN_OPTION_KEYS:
            if key in options:
                conn_params[key] = options[key]

        # Arbitrary additional keywords, without overriding anything above.
        extra_params = options.get('extra_params')
        if isinstance(extra_params, dict):
            for key, value in extra_params.items():
                conn_params.setdefault(key, value)

        return conn_params

    def init_connection_state(self):
        """
        Initialize session state on a newly opened connection.

        Delegates to the PostgreSQL backend. That backend applies the session
        time zone derived from Django's ``USE_TZ``/``TIME_ZONE`` settings and,
        on Django versions that support it, the configured role.

        The inherited version check calls ``check_database_version_supported``,
        which this class overrides to a no-op. The PostgreSQL layer therefore
        does not need to be bypassed.
        """
        super().init_connection_state()

    def check_database_version_supported(self):
        """
        Skip the PostgreSQL minimum-version check.

        OpenGauss uses its own version numbering (e.g. OpenGauss 2.0.0 rather
        than PostgreSQL 12.x), so the PostgreSQL version comparison does not
        apply.
        """
        pass

    @async_unsafe
    def create_cursor(self, name=None):
        """
        Create a driver cursor.

        For a named (server-side) cursor, which suits large result sets, the
        cursor's ``itersize`` is set from ``OPTIONS['cursor_itersize']``
        (default 2000).
        """
        cursor = super().create_cursor(name)

        if name:
            cursor.itersize = self.settings_dict.get('OPTIONS', {}).get('cursor_itersize', 2000)

        return cursor

    def is_usable(self):
        """
        Return True if the current connection can still execute a query.

        The probe uses the raw driver connection rather than ``self.cursor()``.
        Going through Django's cursor machinery would run the connection
        health check, which itself calls ``is_usable()``, and could also
        reconnect or log the probe query.
        """
        if self.connection is None:
            return False
        try:
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except base.Database.Error:
            return False
        return True

    def _set_autocommit(self, autocommit):
        """
        Set the driver connection's autocommit mode.

        Driver errors are translated into Django database exceptions by
        ``wrap_database_errors``.
        """
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def schema_editor(self, *args, **kwargs):
        """
        Return a schema editor instance for this connection.

        The instance is built from ``SchemaEditorClass``, which this class
        sets to the OpenGauss schema editor.
        """
        return self.SchemaEditorClass(self, *args, **kwargs)
